import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from recbole_cdr.utils import get_model
import torch
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv, GATConv, SAGEConv
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from generative_model.generative_layers import VAEReSample
class Predictor(nn.Module):
    """Bias predictor: fuse each embedding with its latent code through a small MLP.

    Users and items are handled by separate networks. Each network takes the
    concatenation ``[embedding || latent]`` of width ``2 * embedding_size`` and
    applies Linear -> ReLU -> Linear, producing vectors of width
    ``embedding_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config['embedding_size']
        self.hidden_size = width
        self.embd_size = width
        # Users first, then items: the same parameter-creation order as before.
        self.user_encode = self._build_mlp(2 * width, width)
        self.item_encode = self._build_mlp(2 * width, width)
        # Neither module below is used by forward() in this class.
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _build_mlp(in_features, out_features):
        """Build Linear(in, out) -> ReLU -> Linear(out, out) as an nn.Sequential."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, user_all_embeddings, item_all_embeddings, user_latent, item_latent):
        """Return ``(encoded_users, encoded_items)``.

        Each input pair is concatenated along the last dimension before it is
        passed through the matching MLP.
        """
        fused_users = torch.cat([user_all_embeddings, user_latent], dim=-1)
        fused_items = torch.cat([item_all_embeddings, item_latent], dim=-1)
        return self.user_encode(fused_users), self.item_encode(fused_items)
class EdgeDataset(Dataset):
    """Map-style dataset over edges stored as two aligned index tensors.

    Item ``k`` of the dataset is the pair ``(target_u[k], target_i[k])``.
    """

    def __init__(self,
                 target_u: torch.Tensor,
                 target_i: torch.Tensor):
        assert target_u.size(0) == target_i.size(0), "target_u and target_i must have the same length"
        self.target_u, self.target_i = target_u, target_i

    def __len__(self):
        return self.target_u.shape[0]

    def __getitem__(self, idx):
        """Return the edge at ``idx`` as a ``(source_node, target_node)`` tuple."""
        user = self.target_u[idx]
        item = self.target_i[idx]
        return user, item
def _get_topk_items4each_user(
        A: torch.Tensor,
        B: torch.Tensor,
        k: int,
        chunk_size: int = 1024):
    """
    在内存友好的前提下，获取 (A x B^T) 每行的 top-k 元素的坐标。
    通过分块计算，避免一次性生成 (a, b) 太大的中间结果。
    参数：
    -------
    A : torch.Tensor of shape (a, dim)
    B : torch.Tensor of shape (b, dim)
    k : 每行要获取的前 k 大元素
    chunk_size : 每次处理 B 的行数（即列块大小），可根据内存情况调节
    返回：
    -------
    top_coords : torch.Tensor of shape (a*k, 2)
                 每个坐标为 (row_index, col_index) 对，
                 其中 row_index 表示 A 中的行号，
                 col_index 表示对应于 B 的全局索引。
    """
    device = A.device
    a, _ = A.shape
    b, _ = B.shape
    # 用于存储全局 top-k 的分值和索引（索引指 B 中的行号，即乘积的列索引）
    global_top_vals = torch.full((a, k), float('-inf'), device=device)
    global_top_inds = torch.full((a, k), -1, dtype=torch.long, device=device)
    start = 0
    while start < b:
        end = min(start + chunk_size, b)
        # 取 B 的一部分
        B_chunk = B[start:end, :]  # shape: (chunk_len, dim)
        # 计算局部矩阵乘积 A x B_chunk^T，结果 shape: (a, chunk_len)
        partial_result = torch.matmul(A, B_chunk.transpose(0, 1))
        # 对每一行取局部 top-k
        partial_top_vals, partial_top_inds = partial_result.topk(k, dim=1)
        # 局部索引转换为全局索引
        partial_top_inds += start
        # 将局部的 top-k 与全局的候选合并
        merged_vals = torch.cat([global_top_vals, partial_top_vals], dim=1)  # (a, 2k)
        merged_inds = torch.cat([global_top_inds, partial_top_inds], dim=1)  # (a, 2k)
        merged_top_vals, merged_top_pos = merged_vals.topk(k, dim=1)
        # 根据 merged_top_pos 得到对应的全局索引
        row_idx = torch.arange(a, device=device).unsqueeze(1)
        final_inds = merged_inds[row_idx, merged_top_pos]
        # 更新全局候选
        global_top_vals = merged_top_vals
        global_top_inds = final_inds
        start = end
    # global_top_inds: (a, k) 表示每行在 B 中的全局索引
    # 为每个结果添加行号，构成 (row, col) 坐标对，形状 (a, k, 2)
    row_idx = torch.arange(a, device=device).unsqueeze(1).expand(a, k)
    top_coords = torch.stack([row_idx, global_top_inds], dim=2)
    # reshape 成 (a*k, 2)
    top_coords = top_coords.view(-1, 2)
    return top_coords
class Deleter_shapley(nn.Module):
    """Learn per-edge keep-weights for the source-domain edges of a user-item graph.

    A frozen, pre-trained ``BPRLightGCN`` (``self.trained_model``) is re-run on
    ``self.cfgraph``. In that run, every source-domain user-item edge (in both
    directions) carries weight ``1 - score``, and all other edges carry weight 1.
    The scores come from the node representations selected by
    ``config['deleter_model']``. The module is trained with a reversed BPR loss
    on target-domain edges (see ``calculate_loss``). ``data_reproduce`` then
    returns the source edges whose learned weight is below
    ``config['delete_threshold']``.

    Node ids: users occupy ``[0, total_num_users)``. Items are offset by
    ``total_num_users`` in the graph.
    """

    def __init__(self,
                 config,
                 source_u,
                 source_i,
                 target_u,
                 target_i,
                 total_num_users,
                 total_num_items,
                 overlapped_num_users,
                 source_num_users,
                 source_num_items,
                 target_num_users,
                 target_num_items,
                 dataset):
        """
        Args:
            config: config mapping. Keys read here: 'device', 'embedding_size',
                'delete_batch', 'load_model', 'deleter_model', 'deleter_layers',
                and 'deleter_num_heads' (GAT only).
            source_u, source_i: source-domain edge endpoints (user ids, item ids).
            target_u, target_i: target-domain edge endpoints; these feed the training dataloader.
            total_num_users, total_num_items: sizes of the free embedding tables.
                total_num_users is also the item-node id offset.
            overlapped_num_users, source_num_users, source_num_items,
            target_num_users, target_num_items: stored as attributes.
                target_num_items bounds negative sampling.
            dataset: passed to ``get_model('BPRLightGCN')`` to build both copies of the model.
        """
        super().__init__()
        self.config = config
        self.device = config['device']
        self.total_num_users = total_num_users
        self.total_num_items = total_num_items
        self.overlapped_num_users = overlapped_num_users
        self.source_num_users = source_num_users
        self.source_num_items = source_num_items
        self.target_num_users = target_num_users
        self.target_num_items = target_num_items
        self.source_u = source_u.to(self.device)
        self.source_i = source_i.to(self.device)
        self.target_u = target_u.to(self.device)
        self.target_i = target_i.to(self.device)
        self.sigmoid = nn.Sigmoid()
        self.predictor = Predictor(config)
        # The dataloader yields batches of config['delete_batch'] target edges.
        # It is built from the tensors as passed in (not moved to self.device).
        self.edge_dataset = EdgeDataset(target_u, target_i)
        self.edge_dataloader = DataLoader(self.edge_dataset, batch_size=config['delete_batch'], shuffle=True)
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users,
                                                 embedding_dim=self.config['embedding_size'])
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items,
                                                 embedding_dim=self.config['embedding_size'])
        self.bpr_loss = BPRLoss()
        # Only the modules registered so far are xavier-initialized. The models
        # and encoders built below keep their own initialization.
        self.apply(xavier_normal_initialization)

        # NOTE(review): `original_config` aliases `config` (no copy). These two
        # assignments therefore also change self.config and the caller's config.
        original_config = config
        original_config['delete'] = False
        original_config['generate'] = False

        # Read the checkpoint once. map_location='cpu' lets a GPU-saved checkpoint
        # load on any machine; load_state_dict copies values onto each model's params.
        checkpoint = torch.load(self.config['load_model'], map_location='cpu')

        # Frozen reference model; forward() re-runs it with learned edge weights.
        self.trained_model = get_model('BPRLightGCN')(original_config, dataset)
        self.trained_model.load_state_dict(checkpoint['state_dict'])
        self.trained_model.eval()
        # NOTE(review): nn.Module.train() recurses into submodules, so a later
        # model.train() call switches trained_model back to training mode.
        self.cfgraph = self.trained_model.merge_dgl_graph
        for param in self.trained_model.parameters():
            param.requires_grad = False

        # Trainable copy; its output is used to score edges when deleter_model == 'self'.
        self.self_model = get_model('BPRLightGCN')(original_config, dataset)
        self.self_model.load_state_dict(checkpoint['state_dict'])

        # Multi-layer GNN encoder. 'GAT' uses GATConv; every other value
        # (including 'Graphconv') uses GraphConv.
        emb_size = self.config['embedding_size']
        num_layers = self.config['deleter_layers']
        if self.config['deleter_model'] == 'GAT':
            self.encoder = nn.ModuleList([
                GATConv(emb_size,
                        emb_size,
                        num_heads=config['deleter_num_heads'],
                        allow_zero_in_degree=True)
                for _ in range(num_layers)
            ])
        else:
            self.encoder = nn.ModuleList([
                GraphConv(emb_size,
                          emb_size,
                          norm='both',
                          weight=True,
                          bias=True,
                          allow_zero_in_degree=True)
                for _ in range(num_layers)
            ])
        # VAE that turns the free embeddings into latents (used when deleter_model == 'VAE').
        self.VAE = VAEReSample(self.config)

    def update_graph(self):
        """Rebuild the user-item graph from the stored target and source edges.

        Item node ids are offset by ``total_num_users``. The graph is made
        bidirected and gets self-loops unless ``config['dglconv'] == 'LightGCN'``.
        Every edge receives weight 1 in ``edata['w']`` (shape (E, 1)). The result
        replaces both ``self.cfgraph`` and ``self.deletedgraph``.
        """
        all_users = torch.cat([self.target_u, self.source_u], dim=0)
        all_items = torch.cat([self.target_i, self.source_i], dim=0) + self.total_num_users
        num_nodes_total = self.total_num_items + self.total_num_users
        # Built on CPU before to_bidirected (presumably a CPU-only op in DGL; TODO confirm).
        g = dgl.graph((all_users, all_items), num_nodes=num_nodes_total).to('cpu')
        g = dgl.to_bidirected(g)
        if self.config['dglconv'] != 'LightGCN':
            g = dgl.add_self_loop(g)
        # TODO(original author): should these weights reuse the previous round's edge weights?
        g.edata['w'] = torch.ones((g.num_edges(), 1), dtype=torch.float32)
        self.cfgraph = g.to(self.device)
        self.deletedgraph = g.to(self.device)

    def compute_score1(self, user_emb, item_emb):
        """Gumbel-softmax edge score of shape (N, 1).

        The logits are ``[dot(user, item), 0]``. The returned value is the
        soft (``hard=False``) probability of the first class, at temperature
        ``config['delete_tau']``.
        """
        x = (user_emb * item_emb).sum(dim=1, keepdim=True)
        logits = torch.cat([x, torch.zeros_like(x)], dim=1)
        gumbel_out = F.gumbel_softmax(logits, tau=self.config['delete_tau'], hard=False)
        scores = gumbel_out[:, 0]
        return scores.unsqueeze(1)

    def compute_score(self, user_emb, item_emb):
        """Return ``sigmoid(dot(user_emb, item_emb))`` row-wise, with shape (N, 1)."""
        x = (user_emb * item_emb).sum(dim=1, keepdim=True)
        return torch.sigmoid(x)

    def find_edges_with_weight_one(self, graph, edge_weights, edge_indices):
        """Return the endpoints of the edges in ``edge_indices`` whose weight is
        below ``config['delete_threshold']``. The method name is historical.

        Args:
            graph: graph whose ``edges()`` returns the ``(src, dst)`` id tensors.
            edge_weights: per-edge weights, shape (E,) or (E, 1).
            edge_indices: edge ids to consider (the source-domain edges).

        Returns:
            ``(src, dst)`` node ids of the selected edges, each 1-D.
        """
        selected_weights = edge_weights[edge_indices]
        # reshape(-1), not squeeze(): it keeps a 1-D mask even when a single edge
        # is selected, and it also accepts (E, 1) weights.
        mask = (selected_weights < self.config['delete_threshold']).reshape(-1)
        src, dst = graph.edges()
        return src[edge_indices][mask], dst[edge_indices][mask]

    def forward(self):
        """Re-run the frozen model with learned keep-weights on the source edges.

        Returns:
            ``(cf_u_emb, cf_i_emb, usr_kl, itm_kl)`` when deleter_model == 'VAE'.
            Otherwise ``(cf_u_emb, cf_i_emb, edge_keep)``, where ``edge_keep`` holds
            one weight per edge of ``self.cfgraph``.

        Side effect: stores a detached copy of the edge weights in
        ``self.edge_weights``, which ``data_reproduce`` reads.
        """
        deleter_model = self.config['deleter_model']
        # 1) Free (learnable) node embeddings: users first, then items.
        latent_embeddings = self.get_ego_embeddings()

        # 2) Node representations used to score the source edges.
        if deleter_model == 'self':
            user_emb, item_emb = self.self_model.forward(self.self_model.merge_dgl_graph)
            outs = torch.cat([user_emb, item_emb], dim=0)
        elif deleter_model == 'VAE':
            usr_emb, itm_emb = torch.split(latent_embeddings,
                                           [self.total_num_users, self.total_num_items])
            usr_lat, itm_lat, usr_kl, itm_kl = self.VAE(True, usr_emb, itm_emb)
            usr_out, itm_out = self.predictor(usr_emb, itm_emb, usr_lat, itm_lat)
            outs = torch.cat([usr_out, itm_out], dim=0)
        else:
            outs = latent_embeddings
            for layer in self.encoder:
                outs = layer(self.cfgraph, outs)
                if deleter_model == 'GAT':
                    outs = outs.mean(dim=1)  # average over attention heads

        # 3) Deletion score for every source edge, shape (|S|,).
        src_u = self.source_u
        src_i = self.source_i + self.total_num_users  # item node ids are offset
        source_score = self.compute_score(outs[src_u], outs[src_i]).reshape(-1)

        # 4) Differentiable weight vector over the whole graph. Source edges in
        #    both directions get 1 - score; every other edge keeps weight 1.
        #    The same source tensors are used for the scores and for the edge
        #    lookup, so each score lands on its own edge.
        edge_keep = torch.ones(self.cfgraph.num_edges(), device=self.device)
        fwd_idx = self.cfgraph.edge_ids(src_u, src_i)
        rev_idx = self.cfgraph.edge_ids(src_i, src_u)
        all_idx = torch.cat([fwd_idx, rev_idx])
        keep_val = (1.0 - source_score).repeat(2)
        edge_keep = edge_keep.scatter(0, all_idx, keep_val)
        # Detached copy for data_reproduce (available in every deleter_model mode).
        self.edge_weights = edge_keep.detach()

        # 5) Frozen model forward. Its parameters have requires_grad=False, so
        #    gradients reach this module only through edge_keep.
        cf_u_emb, cf_i_emb = self.trained_model.forward(self.cfgraph, weights=edge_keep)
        if deleter_model == 'VAE':
            return cf_u_emb, cf_i_emb, usr_kl, itm_kl
        return cf_u_emb, cf_i_emb, edge_keep

    def get_ego_embeddings(self):
        """Return the user and item embedding tables stacked row-wise (users first)."""
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        return torch.cat([user_embeddings, item_embeddings], dim=0)

    def calculate_loss(self, target_u, target_i):
        """Reversed BPR loss on a batch of target-domain edges.

        The BPR arguments are swapped, so the loss decreases as the sampled
        negative items outscore the positive ones. In VAE mode the two KL terms
        are added. Otherwise ``config['delete_lambda_sparse'] * mean(edge_keep)``
        is added.
        """
        target_u = target_u.to(self.device)
        target_i = target_i.to(self.device)
        is_vae = self.config['deleter_model'] == 'VAE'
        if is_vae:
            # forward() returns exactly four values in VAE mode.
            cf_u_embeddings, cf_i_embeddings, user_kl_loss, item_kl_loss = self.forward()
        else:
            cf_u_embeddings, cf_i_embeddings, edge_keep = self.forward()
        num_neg_samples = self.config['delete_neg_samples']
        # Negative items are drawn uniformly from [1, target_num_items).
        neg_i = torch.randint(1,
                              self.target_num_items,
                              (target_u.shape[0], num_neg_samples),
                              device=self.device)
        u_emb = cf_u_embeddings[target_u]
        cf_pos_scores = (u_emb * cf_i_embeddings[target_i]).sum(dim=1)
        cf_neg_scores = (u_emb.unsqueeze(1) * cf_i_embeddings[neg_i]).sum(dim=2).mean(dim=1)
        cf_bpr_loss = self.bpr_loss(cf_neg_scores, cf_pos_scores)
        if is_vae:
            return cf_bpr_loss + user_kl_loss + item_kl_loss
        sparsity = edge_keep.mean()
        return cf_bpr_loss + self.config['delete_lambda_sparse'] * sparsity

    def data_reproduce(self):
        """Return the source edges whose learned weight is below ``delete_threshold``.

        Uses the edge weights stored by the most recent ``forward`` call, and
        prints how many source edges did not pass the threshold.

        Returns:
            ``(new_source_u, new_source_i)`` on CPU. Item ids have the user offset removed.

        Raises:
            RuntimeError: if no forward pass has produced edge weights yet.
        """
        if getattr(self, 'edge_weights', None) is None:
            raise RuntimeError("data_reproduce() needs edge weights; run forward()/calculate_loss() first")
        # Edge ids of the source-domain user -> item edges in the current graph.
        source_pre_edge_idx = self.cfgraph.edge_ids(self.source_u, self.source_i + self.total_num_users)
        new_source_u, new_source_i = self.find_edges_with_weight_one(self.cfgraph, self.edge_weights, source_pre_edge_idx)
        new_source_i = new_source_i - self.total_num_users
        # Number of source edges that were dropped (did not pass the threshold).
        deleted = self.source_u.shape[0] - new_source_u.shape[0]
        print(f"deleted: {deleted}")
        return new_source_u.to('cpu'), new_source_i.to('cpu')
def train2(config, model):
    """Train ``model`` on batches from its own ``edge_dataloader``.

    Runs ``config['delete_epochs']`` epochs of Adam with
    ``lr=config['delete_learning_rate']``, minimizing
    ``model.calculate_loss(batch_u, batch_i)``. After each epoch,
    ReduceLROnPlateau (factor 0.5, patience 5) receives the mean batch loss.
    An epoch that yields no batches is skipped, which avoids dividing by zero.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config['delete_learning_rate'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    model.train()
    for _epoch in range(config['delete_epochs']):
        epoch_loss = 0.0
        batch_count = 0
        for batch_u, batch_i in model.edge_dataloader:
            optimizer.zero_grad()
            loss = model.calculate_loss(batch_u, batch_i)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1
        if batch_count == 0:
            # Empty dataloader: nothing was trained and there is no loss to report.
            continue
        avg_loss = epoch_loss / batch_count
        scheduler.step(avg_loss)  # adjust the learning rate from the mean epoch loss